import os
from typing import Tuple
import gym
import numpy as np
import tqdm
from absl import app, flags
from tensorboardX import SummaryWriter
import warnings

import argparse
import torch
from vae import VAE
import utils
import d4rl
from tqdm import tqdm
import torch.nn as nn

import torch.nn.functional as F
from coolname import generate_slug
import json
from utils import get_lr
import tree

warnings.filterwarnings('ignore')
# CUDA_VISIBLE_DEVICES=3 python traj_vae_loop_otr.py --env antmaze --seed=6 --dataset=large


parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
# dataset
# NOTE(review): the data-loading code below is AntMaze-specific (see the
# usage line above); the 'hopper' default does not have a matching dataset.
parser.add_argument('--env', type=str, default='hopper')
# NOTE(review): --lambda_loss is parsed but not read anywhere in this script;
# the training loop takes its weights from `lambda_loss_list` instead.
parser.add_argument('--lambda_loss', type=float, default=1.0)
parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
parser.add_argument('--version', type=str, default='v2')
parser.add_argument('--k', type=int, default=10)
parser.add_argument('--save_dir', type=str, default='./tmp/')
# model
parser.add_argument('--model', default='VAE', type=str)
parser.add_argument('--hidden_dim', type=int, default=512) 
parser.add_argument('--beta', type=float, default=0.5)
# train
parser.add_argument('--num_iters', type=int, default=int(1e5))
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', default=0.0001, type=float)
parser.add_argument('--scheduler', default=False, action='store_true')
parser.add_argument('--gamma', default=0.95, type=float)
parser.add_argument('--no_max_action', default=False, action='store_true')
parser.add_argument('--clip_to_eps', default=False, action='store_true')
parser.add_argument('--eps', default=1e-4, type=float)
parser.add_argument('--latent_dim', default=None, type=int, help="default: action_dim * 2")
parser.add_argument('--no_normalize', default=False, action='store_true', help="do not normalize states")
args = parser.parse_args()

# Apply --seed: it was parsed but never used, so batch sampling
# (np.random.choice) and VAE initialisation/noise were unseeded.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Use the GPU when one is visible; otherwise fall back to CPU instead of
# failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    """Group a flat stream of transitions into a list of trajectories.

    Each step becomes the tuple
    ``(observation, action, reward, mask, done_float, next_observation)``.
    A new (empty) trajectory is opened after every step whose ``dones_float``
    entry equals 1.0, except after the very last step.

    Returns:
        list of lists of transition tuples; ``[[]]`` when the input is empty.
    """
    total = len(observations)
    current = []
    trajectories = [current]

    for idx in tqdm(range(total)):
        current.append((observations[idx], actions[idx], rewards[idx],
                        masks[idx], dones_float[idx], next_observations[idx]))
        is_last_step = idx == total - 1
        if dones_float[idx] == 1.0 and not is_last_step:
            current = []
            trajectories.append(current)

    return trajectories

def merge_trajectories(trajs):
    """Flatten a list of trajectories and stack each transition field.

    Every transition in every trajectory is collected in order, then
    ``tree.map_structure`` stacks corresponding leaves with ``np.stack``
    (new leading axis), so a list of 6-tuples becomes one 6-tuple of arrays.
    """
    transitions = [step for trajectory in trajs for step in trajectory]
    return tree.map_structure(lambda *leaves: np.stack(leaves), *transitions)

def qlearning_dataset_with_timeouts(env,
                                    dataset=None,
                                    terminate_on_end=False,
                                    disable_goal=True,
                                    **kwargs):
    """Build (s, a, r, s', done) transitions from a D4RL-style dataset dict.

    Episode ends come from a change in ``infos/goal`` between consecutive
    samples when that key is present, and from ``timeouts`` otherwise.
    Samples i and i + 1 form transition i, so the final sample only ever
    appears as a next observation.

    Args:
        env: object providing ``get_dataset(**kwargs)``; only used when
            ``dataset`` is None.
        dataset: optional pre-loaded dict with 'observations', 'actions',
            'rewards', 'terminals' and either 'infos/goal' or 'timeouts'.
            It is not modified.
        terminate_on_end: if False, transitions on which an episode ends by
            timeout / goal change are dropped; if True they are kept and
            marked done.
        disable_goal: if False and 'infos/goal' is present, the goal is
            concatenated onto every observation (axis 1).
        **kwargs: forwarded to ``env.get_dataset``.

    Returns:
        dict of numpy arrays: 'observations', 'actions', 'next_observations',
        'rewards', 'terminals' (bool: dataset terminal OR episode end) and
        'realterminals' (bool: the dataset's own terminal flag only).
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_goal = "infos/goal" in dataset
    observations = dataset['observations']
    if has_goal and not disable_goal:
        # Build a new array rather than writing back into `dataset`, so the
        # caller's dict is left untouched.
        observations = np.concatenate(
            [observations, dataset['infos/goal']], axis=1)

    N = dataset['rewards'].shape[0]
    obs_, next_obs_, action_, reward_, done_, realdone_ = [], [], [], [], [], []
    for i in range(N - 1):
        realdone_bool = bool(dataset['terminals'][i])
        if has_goal:
            # A goal change between samples i and i + 1 marks an episode end.
            final_timestep = bool(
                (dataset['infos/goal'][i] != dataset['infos/goal'][i + 1]).any())
        else:
            final_timestep = bool(dataset['timeouts'][i])

        if final_timestep and not terminate_on_end:
            # Skip this transition and don't apply terminals on the last step
            # of an episode.
            continue

        obs_.append(observations[i])
        next_obs_.append(observations[i + 1])
        action_.append(dataset['actions'][i])
        reward_.append(dataset['rewards'][i])
        # Logical OR keeps the flag boolean; the former `+=` yielded 2 when a
        # terminal coincided with an episode end (terminate_on_end=True),
        # which later `== 1.0` checks then missed.
        done_.append(realdone_bool or final_timestep)
        realdone_.append(realdone_bool)

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'realterminals': np.array(realdone_),
    }

def load_trajectories(name: str, fix_antmaze_timeout=True):
    """Load the D4RL dataset ``name`` and split it into trajectories.

    AntMaze datasets (when ``fix_antmaze_timeout``) are converted with
    ``qlearning_dataset_with_timeouts``; everything else with
    ``d4rl.qlearning_dataset``.

    Args:
        name: gym/D4RL environment id, e.g. 'antmaze-large-v2'.
        fix_antmaze_timeout: use the goal-change based episode splitting for
            AntMaze instead of d4rl's default conversion.

    Returns:
        list of trajectories as produced by ``split_into_trajectories``;
        observations, actions, rewards, dones and next observations are
        cast to float32.
    """
    env = gym.make(name)
    if "antmaze" in name and fix_antmaze_timeout:
        dataset = qlearning_dataset_with_timeouts(env)
    else:
        dataset = d4rl.qlearning_dataset(env)

    observations = dataset['observations']
    next_observations = dataset['next_observations']
    # Transition i closes a trajectory when sample i + 1 does not start from
    # next_observations[i] (a gap in the stream) or when it is terminal.
    # Vectorized form of the former per-row loop; assumes observations are
    # 2-D (num_samples, obs_dim) -- TODO confirm for every dataset used.
    gap = np.linalg.norm(observations[1:] - next_observations[:-1], axis=-1) > 1e-6
    dones_float = np.zeros_like(dataset['rewards'])
    dones_float[:-1] = np.logical_or(gap, dataset['terminals'][:-1] == 1.0)
    # The last transition always ends the final trajectory.
    dones_float[-1] = 1

    # Masks use the dataset's own terminal flag when available, so timeouts /
    # goal changes do not cut bootstrapping.
    if 'realterminals' in dataset:
        masks = 1.0 - dataset['realterminals'].astype(np.float32)
    else:
        masks = 1.0 - dataset['terminals'].astype(np.float32)

    traj = split_into_trajectories(
        observations=observations.astype(np.float32),
        actions=dataset['actions'].astype(np.float32),
        rewards=dataset['rewards'].astype(np.float32),
        masks=masks,
        dones_float=dones_float.astype(np.float32),
        next_observations=next_observations.astype(np.float32))
    return traj

def compute_returns(traj):
    """Return the undiscounted sum of rewards (item 2 of each transition).

    An empty trajectory yields 0.
    """
    return sum(transition[2] for transition in traj)

def compute_rewards_per_step(traj, mean_center, model=None, torch_device=None,
                             exp_lambdas=tuple(range(1, 11))):
    """Score trajectories by the latent distance of their steps to a center.

    For every trajectory with more than one transition, all (state, action)
    pairs are encoded in one batch. The L2 distance from ``mean_center`` to
    each step's latent mean (``nn.PairwiseDistance``) is its score d. For
    each lambda in ``exp_lambdas`` the trajectory's value is
    ``mean(exp(-lambda * d))``. A trajectory whose last reward equals 1.0 is
    counted as "done", otherwise as "not done".

    Args:
        traj: list of trajectories (as from ``split_into_trajectories``);
            items 0, 1, 2 of each transition are state, action, reward.
            (Previously ignored in favour of the global ``trajs``.)
        mean_center: latent center tensor, on the same device as the model
            outputs.
        model: callable returning ``(recon, mean, std)``; defaults to the
            module-level ``vae``.
        torch_device: device for the input tensors; defaults to the
            module-level ``device``.
        exp_lambdas: decay rates for the exponential similarity.

    Returns:
        ``(ndone_array, ndone_number, done_array, done_number)``: per-lambda
        mean values over not-done / done trajectories and their counts. A
        group with no trajectories gives nan (numpy mean of an empty array).
    """
    model = vae if model is None else model
    torch_device = device if torch_device is None else torch_device
    lambdas = np.asarray(exp_lambdas, dtype=np.float64)
    pdist = nn.PairwiseDistance(p=2)

    per_done, per_ndone = [], []
    # Inference only: avoid building autograd graphs for every trajectory.
    with torch.no_grad():
        for trajectory in traj:
            if len(trajectory) <= 1:
                continue
            # Batched as (steps, dim), the same layout used for training.
            states = torch.from_numpy(
                np.stack([t[0] for t in trajectory])).to(torch_device)
            actions = torch.from_numpy(
                np.stack([t[1] for t in trajectory])).to(torch_device)
            _, latent_mean, _ = model(states, actions)
            dists = pdist(mean_center.expand_as(latent_mean), latent_mean)
            dists = dists.double().cpu().numpy()

            # Row k holds exp(-lambdas[k] * d) for every step; average per row.
            values = np.exp(-lambdas[:, None] * dists[None, :]).mean(axis=1)
            bucket = per_done if trajectory[-1][2] == 1.0 else per_ndone
            bucket.append(values.tolist())

    ndone_array = np.mean(np.array(per_ndone), axis=0)
    done_array = np.mean(np.array(per_done), axis=0)
    return ndone_array, len(per_ndone), done_array, len(per_done)
 


# ---------------------------------------------------------------------------
# Environment, training data and expert-demo selection
# ---------------------------------------------------------------------------
env_name = f"{args.env}-{args.dataset}-{args.version}"
env = gym.make(env_name)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
if args.no_max_action:
    max_action = None
print('state_dim:', state_dim, 'action_dim:', action_dim, 'max_action:', max_action)
# Latent size defaults to twice the action dimension unless overridden.
latent_dim = args.latent_dim if args.latent_dim is not None else action_dim * 2

# Original dataset used for VAE training. The path now follows env_name, so
# --env/--version are honoured; this is identical to the old hard-coded
# 'antmaze-<dataset>-v2' path for the AntMaze defaults.
replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
replay_buffer.convert_selfdata(
    f'../implicit_q_learning/datasets/datasets_cvae_full_split_1/{env_name}.hdf5')

# Split the top-k trajectories of the same dataset off as the expert demo.
name = env_name
trajs = load_trajectories(name)
# Ranking score: episode return divided by the distance of the first
# observation's first two entries from the origin (presumably the agent's
# x/y start position in AntMaze -- confirm).
returns = [sum(t[2] for t in traj) / (1e-4 + np.linalg.norm(traj[0][0][:2]))
           for traj in trajs]
idx = np.argpartition(returns, -args.k)[-args.k:]
demo_returns = [returns[i] for i in idx]
expert_demo = merge_trajectories([trajs[i] for i in idx])

if not args.no_normalize:
    mean, std = replay_buffer.normalize_states()
    # NOTE(review): expert_demo states come straight from load_trajectories
    # and are not transformed with (mean, std), yet they are mixed into the
    # same training batches as replay_buffer states -- confirm whether
    # normalize_states rescales replay_buffer.state in place.
else:
    print("No normalize")
if args.clip_to_eps:
    # NOTE(review): only the replay buffer is clipped; expert actions are not.
    replay_buffer.clip_to_eps(args.eps)
states = replay_buffer.state
actions = replay_buffer.action

# ---------------------------------------------------------------------------
# VAE training
# ---------------------------------------------------------------------------
# Each batch holds (batch_size - N_EXPERT) replay-buffer samples followed by
# N_EXPERT expert-demo samples (the expert rows are the last ones).
N_EXPERT = 5
# Weights for the expert latent-variance penalty to sweep over.
lambda_loss_list = [0]
for i in range(1):
    lambda_loss = lambda_loss_list[i]
    # train
    if args.model == 'VAE':
        vae = VAE(state_dim, action_dim, latent_dim, max_action, hidden_dim=args.hidden_dim).to(device)
    else:
        raise NotImplementedError
    optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.scheduler:
        # NOTE(review): scheduler.step() is never called, so --scheduler has
        # no effect on the learning rate. TODO decide the stepping interval.
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.gamma)

    total_size = states.shape[0]
    batch_size = args.batch_size
    # Positions of the expert samples inside each batch, derived from
    # batch_size (previously hard-coded to 251..255, i.e. batch_size == 256).
    expert_rows = torch.arange(batch_size - N_EXPERT, batch_size, device=device)

    for step in tqdm(range(args.num_iters + 2), desc='train'):
        batch_idx = np.random.choice(total_size, batch_size - N_EXPERT)
        expert_idx = np.random.choice(len(expert_demo[0]), N_EXPERT, replace=False)
        # Assumes replay_buffer.state/.action are numpy arrays (they are
        # indexed and converted with torch.from_numpy) -- TODO confirm.
        states_t = np.concatenate([states[batch_idx], expert_demo[0][expert_idx]])
        actions_t = np.concatenate([actions[batch_idx], expert_demo[1][expert_idx]])

        train_states = torch.from_numpy(states_t).to(device)
        train_actions = torch.from_numpy(actions_t).to(device)

        # Variational Auto-Encoder Training. Outputs are named latent_* so
        # they no longer overwrite the normalization (mean, std).
        recon, latent_mean, latent_std = vae(train_states, train_actions)

        # Penalize the spread of the expert samples' latent statistics.
        sub_std = torch.index_select(latent_std, 0, expert_rows)
        sub_mean = torch.index_select(latent_mean, 0, expert_rows)
        std_loss = torch.var(sub_std, 0, unbiased=False).mean()
        mean_loss = torch.var(sub_mean, 0, unbiased=False).mean()

        recon_loss = F.mse_loss(recon, train_actions)
        KL_loss = -0.5 * (1 + torch.log(latent_std.pow(2)) - latent_mean.pow(2) - latent_std.pow(2)).mean()
        vae_loss = recon_loss + args.beta * KL_loss + std_loss * lambda_loss + mean_loss * lambda_loss

        optimizer.zero_grad()
        vae_loss.backward()
        optimizer.step()

        # Checkpoint at step num_iters (was the literal 100000, which only
        # matched the default --num_iters).
        if step == args.num_iters:
            os.makedirs('./models', exist_ok=True)
            torch.save(vae.state_dict(), './models/vae_model_%s_%s_%s_%s_%s.pt' % (args.env, args.dataset, args.k, lambda_loss, step))

#    # load data
#    state_dim = env.observation_space.shape[0]
#    action_dim = env.action_space.shape[0]
#    max_action = float(env.action_space.high[0])
##    latent_dim = action_dim * 2
#
#    # calculate center point of expert distribution
#    states_expert = expert_demo[0]
#    actions_expert = expert_demo[1]
#    train_states = torch.from_numpy(states_expert).to(device)
#    train_actions = torch.from_numpy(actions_expert).to(device)
#    _, mean_all, std_all = vae(train_states, train_actions)
#    mean_center = torch.mean(mean_all, 0)
#    std_center = torch.mean(std_all, 0)#
#
#    eval_results = []
#    #fmt_dict = '[\'%.8f\', \'%.8f\', \'%.8f\', \'%.8f\', \'%.8f\', \'%.8f\', \'%.8f\', \'%.8f\', \'%.8f\', \'%.8f\']'
#    ndone_array, ndone_number, done_array, done_number = compute_rewards_per_step(trajs, mean_center)
#    eval_results.append(ndone_array)
#    eval_results.append(done_array)
#    print(eval_results)
#    np.savetxt(os.path.join(args.save_dir, '%s_%s_%s.txt' % (args.env, args.dataset, lambda_loss)),
#               eval_results, fmt="%.8f", delimiter=",")
    #print(ndone_array)
    #print(done_array)
    #print(args.dataset, lambda_loss)
    #lambda_loss += 0.1

# Summary of the expert-demo selection.
print(f'trajs numbers: {len(trajs)}')
print(f'length of selected expert demo: {len(expert_demo[0])}')
print("demo returns {}, mean {}".format(demo_returns, np.mean(demo_returns)))
#print('numbers of the ndone trajs:', ndone_number)
#print('numbers of the done trajs:', done_number)
    